{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import argparse\n",
    "import sys\n",
    "\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "FLAGS = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-2-cf0ad10d35ed>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /home/sika/miniconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /home/sika/miniconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /home/sika/miniconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/sika/miniconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/sika/miniconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST splits from `data_dir`; labels are requested as one-hot vectors.\n",
    "data_dir = 'MNIST_data/'\n",
    "mnist = input_data.read_data_sets(train_dir=data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter: step size used by the gradient-descent optimizer.\n",
    "learning_rate = 0.5\n",
    "\n",
    "# Graph inputs (fed at run time through feed_dict).\n",
    "# Each image is 28 x 28 pixels, flattened to 784 values; the batch size is left open.\n",
    "x = tf.placeholder(tf.float32, shape=[None, 784])\n",
    "# Targets are one-hot encodings of the digits 0-9.\n",
    "y = tf.placeholder(tf.float32, shape=[None, 10])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "学习率设置得太小会导致训练太慢。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hidden layer parameters: 784 inputs -> 300 hidden units.\n",
    "# Weights start from a narrow normal distribution (stddev 0.03);\n",
    "# the biases use tf.random_normal's default stddev.\n",
    "w1 = tf.Variable(tf.random_normal(shape=[784, 300], stddev=0.03), name='W1')\n",
    "b1 = tf.Variable(tf.random_normal(shape=[300]), name='b1')\n",
    "\n",
    "# Output layer parameters: 300 hidden units -> 10 classes.\n",
    "w2 = tf.Variable(tf.random_normal(shape=[300, 10], stddev=0.03), name='W2')\n",
    "b2 = tf.Variable(tf.random_normal(shape=[10]), name='b2')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "隐层公式:\n",
    "\n",
    "$$z = xW_1 + b_1$$\n",
    "\n",
    "$$h = \\mathrm{ReLU}(z)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hidden layer activation: affine transform of the inputs followed by ReLU.\n",
    "hidden_out = tf.nn.relu(tf.matmul(x, w1) + b1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调用 TensorFlow 的 softmax,把输出层的结果转换为各类别的概率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output layer: affine transform of the hidden activations, then softmax\n",
    "# so that each row is a probability distribution over the classes.\n",
    "hypothesis = tf.nn.softmax(tf.matmul(hidden_out, w2) + b2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "计算交叉熵损失:先把预测概率裁剪到 (0, 1) 区间内以避免 log(0),再对每个类别按二元交叉熵求和,最后对批次取平均"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the probabilities strictly inside (0, 1) so neither log() below receives 0.\n",
    "y_clipped = tf.clip_by_value(hypothesis, clip_value_min=1e-10, clip_value_max=0.9999999)\n",
    "\n",
    "# Binary cross-entropy term for every class of every sample.\n",
    "per_class_log_likelihood = y * tf.log(y_clipped) + (1 - y) * tf.log(1 - y_clipped)\n",
    "\n",
    "# Sum the terms over the class axis, then average over the batch.\n",
    "cross_entropy = -tf.reduce_mean(tf.reduce_sum(per_class_log_likelihood, axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training op: one plain gradient-descent step on the cross-entropy loss.\n",
    "# Note that `optimizer` holds what minimize() returns, not the optimizer object itself.\n",
    "optimizer = tf.train.GradientDescentOptimizer(\n",
    "    learning_rate=learning_rate\n",
    ").minimize(loss=cross_entropy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义初始化operation和准确率node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Op that initializes the global variables of the graph.\n",
    "init_op = tf.global_variables_initializer()\n",
    "\n",
    "# Accuracy node: a sample counts as correct when the most probable predicted\n",
    "# class equals the class marked in its one-hot label.\n",
    "correct_prediction = tf.equal(tf.argmax(hypothesis, axis=1), tf.argmax(y, axis=1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.1111\n",
      "200 0.921\n",
      "400 0.9457\n",
      "600 0.9541\n",
      "800 0.9532\n",
      "1000 0.9654\n",
      "1200 0.9717\n",
      "1400 0.9716\n",
      "1600 0.9711\n",
      "1800 0.9733\n",
      "2000 0.972\n",
      "2200 0.9748\n",
      "2400 0.974\n",
      "2600 0.974\n",
      "2800 0.9762\n",
      "3000 0.9782\n",
      "3200 0.9766\n",
      "3400 0.975\n",
      "3600 0.9772\n",
      "3800 0.9761\n",
      "4000 0.9766\n",
      "4200 0.9775\n",
      "4400 0.9767\n",
      "4600 0.9788\n",
      "4800 0.9762\n",
      "5000 0.9765\n",
      "5200 0.9775\n",
      "5400 0.9792\n",
      "5600 0.9781\n",
      "5800 0.9788\n",
      "6000 0.9799\n",
      "6200 0.9808\n"
     ]
    }
   ],
   "source": [
    "# Training settings.\n",
    "num_steps = 6400   # total number of gradient-descent steps\n",
    "batch_size = 100   # training samples per step\n",
    "eval_every = 200   # report test accuracy every this many steps\n",
    "\n",
    "# Use the session as a context manager so it is closed (and its resources\n",
    "# released) when training finishes or an error interrupts it.\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init_op)\n",
    "    for i in range(num_steps):\n",
    "        batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)\n",
    "        # Fetch only the training op; accuracy is evaluated separately on the test set below.\n",
    "        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})\n",
    "        if i % eval_every == 0:\n",
    "            # Prints the step index followed by the accuracy on the full test split.\n",
    "            print(i, sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练约 6400 步后,测试集准确率约为 98%,可以满足要求。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
